{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/finetuning/cross_encoder_finetuning/cross_encoder_finetuning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to Fine-Tune a Cross-Encoder Using LlamaIndex"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you're opening this notebook on Colab, you will probably need to install LlamaIndex 🦙."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index-finetuning-cross-encoders\n",
    "%pip install llama-index-llms-openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install llama-index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download Requirements\n",
    "!pip install datasets --quiet\n",
    "!pip install sentence-transformers --quiet\n",
    "!pip install openai --quiet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process\n",
    "\n",
    "- Download the QASPER dataset from the HuggingFace Hub using the Datasets library (https://huggingface.co/datasets/allenai/qasper).\n",
    "\n",
    "- Randomly sample 800 papers from the train split and 80 papers from the test split.\n",
    "\n",
    "- Use the 800 training papers, each of which comes with questions written about it, to generate a dataset in the format required for CrossEncoder fine-tuning. In this format, a single fine-tuning sample consists of two texts (a question and a context) and a score of either 0 or 1, where 1 means the question and context are relevant to each other and 0 means they are not.\n",
    "\n",
    "- Use the 80 test papers to build two kinds of evaluation datasets:\n",
    "  * RAG Eval Dataset: each sample consists of a research paper's content, a list of questions about the paper, and the answers to those questions. We keep only questions with long, free-form answers, since these allow a fairer comparison with RAG-generated answers.\n",
    "\n",
    "  * Reranking Eval Dataset: each sample consists of a research paper's content, a list of questions about the paper, and, for each question, the context from the paper that is relevant to it.\n",
    "\n",
    "- We fine-tune the cross-encoder using the helper utilities in LlamaIndex and push it to the HuggingFace Hub, logging in with a HuggingFace access token (tokens can be created at https://huggingface.co/settings/tokens).\n",
    "\n",
    "- We evaluate both datasets, each with its own metric, in three cases:\n",
    "     1. OpenAI embeddings alone, without any reranker\n",
    "     2. OpenAI embeddings combined with cross-encoder/ms-marco-MiniLM-L-12-v2 as the reranker\n",
    "     3. OpenAI embeddings combined with our fine-tuned cross-encoder as the reranker\n",
    "\n",
    "* Evaluation criteria for each eval dataset\n",
    "  - Hits metric: for the Reranking Eval Dataset, we use the retriever and node post-processor functionality of LlamaIndex and count, in each case, how many times the relevant context is retrieved. We call this count the hits metric.\n",
    "\n",
    "  - Pairwise Comparison Evaluator: for the RAG Eval Dataset, we use the Pairwise Comparison Evaluator provided by LlamaIndex (https://github.com/run-llama/llama_index/blob/main/llama_index/evaluation/pairwise.py) to compare the responses of the query engine built in each case against the reference free-form answers.\n"
   ]
  },
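  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The question/context/score pair format described above can be sketched in plain Python. This is a minimal illustration under stated assumptions: `PairSample`, `positive_ratio` and the toy pairs are hypothetical names invented here, not part of LlamaIndex (the library's own sample class, `CrossEncoderFinetuningDatasetSample`, appears later in this notebook). Checking the share of pairs labelled 1 can serve as a quick sanity check on label balance before fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import List\n",
    "\n",
    "\n",
    "# Hypothetical stand-in for one cross-encoder training pair (not a LlamaIndex class)\n",
    "@dataclass\n",
    "class PairSample:\n",
    "    query: str\n",
    "    context: str\n",
    "    score: int  # 1 = question and context are relevant, 0 = not relevant\n",
    "\n",
    "\n",
    "def positive_ratio(samples: List[PairSample]) -> float:\n",
    "    # Fraction of pairs labelled relevant (score == 1); 0.0 for an empty list\n",
    "    if not samples:\n",
    "        return 0.0\n",
    "    return sum(s.score for s in samples) / len(samples)\n",
    "\n",
    "\n",
    "# Toy pairs for illustration only\n",
    "toy_samples = [\n",
    "    PairSample('Which datasets are used?', 'We evaluate on the ADE and CoNLL04 datasets.', 1),\n",
    "    PairSample('Which datasets are used?', 'Our model has about 6 million parameters.', 0),\n",
    "    PairSample('How many parameters does the model have?', 'Our model has about 6 million parameters.', 1),\n",
    "    PairSample('How many parameters does the model have?', 'We evaluate on the ADE and CoNLL04 datasets.', 0),\n",
    "]\n",
    "print(positive_ratio(toy_samples))  # 0.5"
   ]
  },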
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "import random\n",
    "\n",
    "\n",
    "# Download QASPER dataset from HuggingFace https://huggingface.co/datasets/allenai/qasper\n",
    "dataset = load_dataset(\"allenai/qasper\")\n",
    "\n",
    "# Access the train, validation, and test splits of the dataset\n",
    "train_dataset = dataset[\"train\"]\n",
    "validation_dataset = dataset[\"validation\"]\n",
    "test_dataset = dataset[\"test\"]\n",
    "\n",
    "random.seed(42)  # Set a random seed for reproducibility\n",
    "\n",
    "# Randomly sample 800 rows from the training split\n",
    "train_sampled_indices = random.sample(range(len(train_dataset)), 800)\n",
    "train_samples = [train_dataset[i] for i in train_sampled_indices]\n",
    "\n",
    "\n",
    "# Randomly sample 80 rows from the test split\n",
    "test_sampled_indices = random.sample(range(len(test_dataset)), 80)\n",
    "test_samples = [test_dataset[i] for i in test_sampled_indices]\n",
    "\n",
    "# Now we have 800 research papers for training and 80 research papers to evaluate on"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## QASPER Dataset\n",
    "* Each row has the following 6 columns:\n",
    "    - id: unique identifier of the research paper\n",
    "\n",
    "    - title: title of the research paper\n",
    "\n",
    "    - abstract: abstract of the research paper\n",
    "\n",
    "    - full_text: full text of the research paper, stored as parallel `section_name` and `paragraphs` lists\n",
    "\n",
    "    - qas: questions and answers pertaining to the research paper\n",
    "\n",
    "    - figures_and_tables: figures and tables of the research paper\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the full text and the questions of each paper in the QASPER training samples to build the cross-encoder fine-tuning data\n",
    "from typing import List\n",
    "\n",
    "\n",
    "# Utility function to get the full text of a research paper from the dataset\n",
    "def get_full_text(sample: dict) -> str:\n",
    "    \"\"\"\n",
    "    :param dict sample: the row sample from QASPER\n",
    "    \"\"\"\n",
    "    title = sample[\"title\"]\n",
    "    abstract = sample[\"abstract\"]\n",
    "    sections_list = sample[\"full_text\"][\"section_name\"]\n",
    "    paragraph_list = sample[\"full_text\"][\"paragraphs\"]\n",
    "    combined_sections_with_paras = \"\"\n",
    "    if len(sections_list) == len(paragraph_list):\n",
    "        combined_sections_with_paras += title + \"\\t\"\n",
    "        combined_sections_with_paras += abstract + \"\\t\"\n",
    "        for index in range(0, len(sections_list)):\n",
    "            combined_sections_with_paras += str(sections_list[index]) + \"\\t\"\n",
    "            combined_sections_with_paras += \"\".join(paragraph_list[index])\n",
    "        return combined_sections_with_paras\n",
    "\n",
    "    else:\n",
    "        print(\"Not the same number of sections as paragraphs list\")\n",
    "\n",
    "\n",
    "# Utility function to extract the list of questions from the dataset\n",
    "def get_questions(sample: dict) -> List[str]:\n",
    "    \"\"\"\n",
    "    :param dict sample: the row sample from QASPER\n",
    "    \"\"\"\n",
    "    questions_list = sample[\"qas\"][\"question\"]\n",
    "    return questions_list\n",
    "\n",
    "\n",
    "doc_qa_dict_list = []\n",
    "\n",
    "for train_sample in train_samples:\n",
    "    full_text = get_full_text(train_sample)\n",
    "    questions_list = get_questions(train_sample)\n",
    "    local_dict = {\"paper\": full_text, \"questions\": questions_list}\n",
    "    doc_qa_dict_list.append(local_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "800"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(doc_qa_dict_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save training data as a csv\n",
    "import pandas as pd\n",
    "\n",
    "df_train = pd.DataFrame(doc_qa_dict_list)\n",
    "df_train.to_csv(\"train.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate RAG Eval test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the evaluation papers, questions and answers\n",
    "\"\"\"\n",
    "The answers field in the dataset follows the format below:\n",
    "Unanswerable answers have \"unanswerable\" set to true.\n",
    "\n",
    "The remaining answers have exactly one of the following fields non-empty:\n",
    "\n",
    "\"extractive_spans\" are spans in the paper which serve as the answer.\n",
    "\"free_form_answer\" is a written-out answer.\n",
    "\"yes_no\" is true iff the answer is Yes, and false iff the answer is No.\n",
    "\n",
    "We accept only free-form answers and set every other kind of answer to 'Unacceptable'.\n",
    "This gives a fairer evaluation of the query engine with the pairwise comparison evaluator, which uses GPT-4, a model biased towards preferring longer answers.\n",
    "https://www.anyscale.com/blog/a-comprehensive-guide-for-building-rag-based-llm-applications-part-1\n",
    "\n",
    "So for 'yes_no' answers the evaluator may favour the query engine's answers over the reference answers,\n",
    "and for extractive spans it may favour the reference answers over the query engine's answers.\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "eval_doc_qa_answer_list = []\n",
    "\n",
    "\n",
    "# Utility function to extract answers from the dataset\n",
    "def get_answers(sample: dict) -> List[str]:\n",
    "    \"\"\"\n",
    "    :param dict sample: the row sample from QASPER\n",
    "    \"\"\"\n",
    "    final_answers_list = []\n",
    "    answers = sample[\"qas\"][\"answers\"]\n",
    "    for answer in answers:\n",
    "        local_answer = \"\"\n",
    "        # Use the first annotation for each question\n",
    "        types_of_answers = answer[\"answer\"][0]\n",
    "        if not types_of_answers[\"unanswerable\"]:\n",
    "            if types_of_answers[\"free_form_answer\"] != \"\":\n",
    "                local_answer = types_of_answers[\"free_form_answer\"]\n",
    "            else:\n",
    "                local_answer = \"Unacceptable\"\n",
    "        else:\n",
    "            local_answer = \"Unacceptable\"\n",
    "\n",
    "        final_answers_list.append(local_answer)\n",
    "\n",
    "    return final_answers_list\n",
    "\n",
    "\n",
    "for test_sample in test_samples:\n",
    "    full_text = get_full_text(test_sample)\n",
    "    questions_list = get_questions(test_sample)\n",
    "    answers_list = get_answers(test_sample)\n",
    "    local_dict = {\n",
    "        \"paper\": full_text,\n",
    "        \"questions\": questions_list,\n",
    "        \"answers\": answers_list,\n",
    "    }\n",
    "    eval_doc_qa_answer_list.append(local_dict)"
   ]
  },
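  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The answer rules in the docstring above can be illustrated on toy records. This is a sketch under stated assumptions: `answer_type` and `toy_answers` are hypothetical, and the dictionaries only mimic the QASPER answer fields named in the docstring (`unanswerable`, `extractive_spans`, `free_form_answer`, `yes_no`); they are not real QASPER rows. It makes explicit which answers `get_answers` keeps: only the free-form ones, while every other kind becomes 'Unacceptable'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper that classifies one answer dict using the docstring's rules\n",
    "def answer_type(answer: dict) -> str:\n",
    "    if answer['unanswerable']:\n",
    "        return 'unanswerable'\n",
    "    if answer['free_form_answer'] != '':\n",
    "        return 'free_form'\n",
    "    if answer['extractive_spans']:\n",
    "        return 'extractive'\n",
    "    return 'yes_no'\n",
    "\n",
    "\n",
    "# Toy answer dicts (not real QASPER rows), one per answer kind\n",
    "toy_answers = [\n",
    "    {'unanswerable': True, 'extractive_spans': [], 'free_form_answer': '', 'yes_no': None},\n",
    "    {'unanswerable': False, 'extractive_spans': ['BERT'], 'free_form_answer': '', 'yes_no': None},\n",
    "    {'unanswerable': False, 'extractive_spans': [], 'free_form_answer': 'They fine-tune BERT on the task data.', 'yes_no': None},\n",
    "    {'unanswerable': False, 'extractive_spans': [], 'free_form_answer': '', 'yes_no': True},\n",
    "]\n",
    "kinds = [answer_type(a) for a in toy_answers]\n",
    "print(kinds)  # ['unanswerable', 'extractive', 'free_form', 'yes_no']\n",
    "\n",
    "# get_answers keeps only free-form answers and maps the rest to 'Unacceptable'\n",
    "kept = sum(kind == 'free_form' for kind in kinds)\n",
    "print(kept)  # 1"
   ]
  },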
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "80\n"
     ]
    }
   ],
   "source": [
    "len(eval_doc_qa_answer_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save eval data as a csv\n",
    "import pandas as pd\n",
    "\n",
    "df_test = pd.DataFrame(eval_doc_qa_answer_list)\n",
    "df_test.to_csv(\"test.csv\")\n",
    "\n",
    "# The RAG Eval test data can be found at the Dropbox link below\n",
    "# https://www.dropbox.com/scl/fi/3lmzn6714oy358mq0vawm/test.csv?rlkey=yz16080te4van7fvnksi9kaed&dl=0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Finetuning Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the latest version of llama-index\n",
    "!pip install llama-index --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the training dataset from the QASPER train data collected above, in the format required by the cross-encoder fine-tuning engine\n",
    "import os\n",
    "from llama_index.core import SimpleDirectoryReader\n",
    "import openai\n",
    "from llama_index.finetuning.cross_encoders.dataset_gen import (\n",
    "    generate_ce_fine_tuning_dataset,\n",
    "    generate_synthetic_queries_over_documents,\n",
    ")\n",
    "\n",
    "from llama_index.finetuning.cross_encoders import CrossEncoderFinetuneEngine\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-\"  # replace with your OpenAI API key\n",
    "openai.api_key = os.environ[\"OPENAI_API_KEY\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import Document\n",
    "\n",
    "final_finetuning_data_list = []\n",
    "for paper in doc_qa_dict_list:\n",
    "    questions_list = paper[\"questions\"]\n",
    "    documents = [Document(text=paper[\"paper\"])]\n",
    "    local_finetuning_dataset = generate_ce_fine_tuning_dataset(\n",
    "        documents=documents,\n",
    "        questions_list=questions_list,\n",
    "        max_chunk_length=256,\n",
    "        top_k=5,\n",
    "    )\n",
    "    final_finetuning_data_list.extend(local_finetuning_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "11674"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total samples in the final fine-tuning dataset\n",
    "len(final_finetuning_data_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save final fine-tuning dataset\n",
    "import pandas as pd\n",
    "\n",
    "df_finetuning_dataset = pd.DataFrame(final_finetuning_data_list)\n",
    "df_finetuning_dataset.to_csv(\"fine_tuning.csv\")\n",
    "\n",
    "# The fine-tuning dataset can be found at the Dropbox link below:\n",
    "# https://www.dropbox.com/scl/fi/zu6vtisp1j3wg2hbje5xv/fine_tuning.csv?rlkey=0jr6fud8sqk342agfjbzvwr9x&dl=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load fine-tuning dataset\n",
    "\n",
    "finetuning_dataset = final_finetuning_data_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CrossEncoderFinetuningDatasetSample(query='Do they repot results only on English data?', context='addition to precision, recall, and F1 scores for both tasks, we show the average of the F1 scores across both tasks. On the ADE dataset, we achieve SOTA results for both the NER and RE tasks. On the CoNLL04 dataset, we achieve SOTA results on the NER task, while our performance on the RE task is competitive with other recent models. On both datasets, we achieve SOTA results when considering the average F1 score across both tasks. The largest gain relative to the previous SOTA performance is on the RE task of the ADE dataset, where we see an absolute improvement of 4.5 on the macro-average F1 score.While the model of Eberts and Ulges eberts2019span outperforms our proposed architecture on the CoNLL04 RE task, their results come at the cost of greater model complexity. As mentioned above, Eberts and Ulges fine-tune the BERTBASE model, which has 110 million trainable parameters. In contrast, given the hyperparameters used for final training on the CoNLL04 dataset, our proposed architecture has approximately 6 million trainable parameters.The fact that the optimal number of task-specific layers differed between the two datasets demonstrates the', score=0)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "finetuning_dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Reranking Eval test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download RAG Eval test data\n",
    "!wget -O test.csv https://www.dropbox.com/scl/fi/3lmzn6714oy358mq0vawm/test.csv?rlkey=yz16080te4van7fvnksi9kaed&dl=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of papers in the test sample:- 80\n"
     ]
    }
   ],
   "source": [
    "# Generate Reranking Eval Dataset from the Eval data\n",
    "import pandas as pd\n",
    "import ast  # Used to safely evaluate the string as a list\n",
    "\n",
    "# Load Eval Data\n",
    "df_test = pd.read_csv(\"test.csv\", index_col=0)\n",
    "\n",
    "df_test[\"questions\"] = df_test[\"questions\"].apply(ast.literal_eval)\n",
    "df_test[\"answers\"] = df_test[\"answers\"].apply(ast.literal_eval)\n",
    "print(f\"Number of papers in the test sample:- {len(df_test)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import Document\n",
    "\n",
    "final_eval_data_list = []\n",
    "for index, row in df_test.iterrows():\n",
    "    documents = [Document(text=row[\"paper\"])]\n",
    "    query_list = row[\"questions\"]\n",
    "    local_eval_dataset = generate_ce_fine_tuning_dataset(\n",
    "        documents=documents,\n",
    "        questions_list=query_list,\n",
    "        max_chunk_length=256,\n",
    "        top_k=5,\n",
    "    )\n",
    "    relevant_query_list = []\n",
    "    relevant_context_list = []\n",
    "\n",
    "    for item in local_eval_dataset:\n",
    "        if item.score == 1:\n",
    "            relevant_query_list.append(item.query)\n",
    "            relevant_context_list.append(item.context)\n",
    "\n",
    "    if len(relevant_query_list) > 0:\n",
    "        final_eval_data_list.append(\n",
    "            {\n",
    "                \"paper\": row[\"paper\"],\n",
    "                \"questions\": relevant_query_list,\n",
    "                \"context\": relevant_context_list,\n",
    "            }\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "38"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Length of Reranking Eval Dataset\n",
    "len(final_eval_data_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save Reranking eval dataset\n",
    "import pandas as pd\n",
    "\n",
    "df_reranking_dataset = pd.DataFrame(final_eval_data_list)\n",
    "df_reranking_dataset.to_csv(\"reranking_test.csv\")\n",
    "\n",
    "# The reranking dataset can be found at the Dropbox link below\n",
    "# https://www.dropbox.com/scl/fi/mruo5rm46k1acm1xnecev/reranking_test.csv?rlkey=hkniwowq0xrc3m0ywjhb2gf26&dl=0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetune Cross-Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install huggingface_hub --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# Initialise the cross-encoder fine-tuning engine\n",
    "finetuning_engine = CrossEncoderFinetuneEngine(\n",
    "    dataset=finetuning_dataset, epochs=2, batch_size=8\n",
    ")\n",
    "\n",
    "# Finetune the cross-encoder model\n",
    "finetuning_engine.finetune()"
   ]
  },
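  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check on training cost, the number of optimizer steps per epoch is the number of training pairs divided by the batch size, rounded up. The sketch below hard-codes the 11,674 pairs counted earlier and the `epochs=2`, `batch_size=8` passed to `CrossEncoderFinetuneEngine`. It assumes the final partial batch is kept rather than dropped, which is an assumption about the underlying data loader, not something stated in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_pairs = 11674  # len(final_finetuning_data_list) reported earlier\n",
    "batch_size = 8\n",
    "epochs = 2\n",
    "\n",
    "# Assumes the last, partial batch is kept (no drop_last)\n",
    "steps_per_epoch = math.ceil(num_pairs / batch_size)\n",
    "total_steps = steps_per_epoch * epochs\n",
    "print(steps_per_epoch, total_steps)  # 1460 2920"
   ]
  },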
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Push model to HuggingFace Hub\n",
    "finetuning_engine.push_to_hub(\n",
    "    repo_id=\"bpHigh/Cross-Encoder-LLamaIndex-Demo-v2\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reranking Evaluation"
   ]
  },
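  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The hits metric from the Process section can be sketched as a plain function. This is a minimal, hypothetical version: `count_hits`, its `top_k` parameter and the exact-string match between a retrieved chunk and the gold context are assumptions made for illustration, not necessarily how the evaluation in this notebook compares contexts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "\n",
    "# Hypothetical sketch: a question scores a hit when its gold context is among\n",
    "# the first top_k chunks returned by the retriever (optionally after reranking)\n",
    "def count_hits(retrieved: List[List[str]], gold_contexts: List[str], top_k: int = 3) -> int:\n",
    "    if len(retrieved) != len(gold_contexts):\n",
    "        raise ValueError('need one retrieved list per gold context')\n",
    "    hits = 0\n",
    "    for chunks, gold in zip(retrieved, gold_contexts):\n",
    "        if gold in chunks[:top_k]:  # exact string match (assumption)\n",
    "            hits += 1\n",
    "    return hits\n",
    "\n",
    "\n",
    "# Toy retrieval results for three questions\n",
    "toy_retrieved = [\n",
    "    ['ctx A', 'ctx B', 'ctx C'],\n",
    "    ['ctx D', 'ctx E', 'ctx F'],\n",
    "    ['ctx G', 'ctx H', 'ctx I'],\n",
    "]\n",
    "toy_gold = ['ctx B', 'ctx Z', 'ctx G']\n",
    "print(count_hits(toy_retrieved, toy_gold))  # 2\n",
    "print(count_hits(toy_retrieved, toy_gold, top_k=1))  # 1"
   ]
  },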
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install nest-asyncio --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Allow nested event loops so async LlamaIndex calls can run inside the notebook's already-running loop\n",
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download Reranking test data\n",
    "!wget -O reranking_test.csv https://www.dropbox.com/scl/fi/mruo5rm46k1acm1xnecev/reranking_test.csv?rlkey=hkniwowq0xrc3m0ywjhb2gf26&dl=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of papers in the reranking eval dataset:- 38\n"
     ]
    }
   ],
   "source": [
    "# Load Reranking Dataset\n",
    "import pandas as pd\n",
    "import ast\n",
    "\n",
    "df_reranking = pd.read_csv(\"reranking_test.csv\", index_col=0)\n",
    "df_reranking[\"questions\"] = df_reranking[\"questions\"].apply(ast.literal_eval)\n",
    "df_reranking[\"context\"] = df_reranking[\"context\"].apply(ast.literal_eval)\n",
    "print(f\"Number of papers in the reranking eval dataset:- {len(df_reranking)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-e1282d93-cd7a-4536-a8a7-4f4ac8db179b\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>paper</th>\n",
       "      <th>questions</th>\n",
       "      <th>context</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Identifying Condition-Action Statements in Med...</td>\n",
       "      <td>[What supervised machine learning models do th...</td>\n",
       "      <td>[Identifying Condition-Action Statements in Me...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                                               paper  \\\n",
       "0  Identifying Condition-Action Statements in Med...   \n",
       "\n",
       "                                           questions  \\\n",
       "0  [What supervised machine learning models do th...   \n",
       "\n",
       "                                             context  \n",
       "0  [Identifying Condition-Action Statements in Me...  "
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_reranking.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "# Evaluate by counting hits for each (question, context) pair:\n",
    "# retrieve the top-k nodes for the question, and count a hit when\n",
    "# a retrieved node's text contains the context (or vice versa)\n",
    "from llama_index.core.postprocessor import SentenceTransformerRerank\n",
    "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Response\n",
    "from llama_index.core.retrievers import VectorIndexRetriever\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.core import Document\n",
    "from llama_index.core import Settings\n",
    "\n",
    "import os\n",
    "import openai\n",
    "import pandas as pd\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-\"\n",
    "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
    "\n",
    "Settings.chunk_size = 256\n",
    "\n",
    "rerank_base = SentenceTransformerRerank(\n",
    "    model=\"cross-encoder/ms-marco-MiniLM-L-12-v2\", top_n=3\n",
    ")\n",
    "\n",
    "rerank_finetuned = SentenceTransformerRerank(\n",
    "    model=\"bpHigh/Cross-Encoder-LLamaIndex-Demo-v2\", top_n=3\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "without_reranker_hits = 0\n",
    "base_reranker_hits = 0\n",
    "finetuned_reranker_hits = 0\n",
    "total_number_of_context = 0\n",
    "for index, row in df_reranking.iterrows():\n",
    "    documents = [Document(text=row[\"paper\"])]\n",
    "    query_list = row[\"questions\"]\n",
    "    context_list = row[\"context\"]\n",
    "\n",
    "    assert len(query_list) == len(context_list)\n",
    "    vector_index = VectorStoreIndex.from_documents(documents)\n",
    "\n",
    "    retriever_without_reranker = vector_index.as_query_engine(\n",
    "        similarity_top_k=3, response_mode=\"no_text\"\n",
    "    )\n",
    "    retriever_with_base_reranker = vector_index.as_query_engine(\n",
    "        similarity_top_k=8,\n",
    "        response_mode=\"no_text\",\n",
    "        node_postprocessors=[rerank_base],\n",
    "    )\n",
    "    retriever_with_finetuned_reranker = vector_index.as_query_engine(\n",
    "        similarity_top_k=8,\n",
    "        response_mode=\"no_text\",\n",
    "        node_postprocessors=[rerank_finetuned],\n",
    "    )\n",
    "\n",
    "    for index in range(0, len(query_list)):\n",
    "        query = query_list[index]\n",
    "        context = context_list[index]\n",
    "        total_number_of_context += 1\n",
    "\n",
    "        response_without_reranker = retriever_without_reranker.query(query)\n",
    "        without_reranker_nodes = response_without_reranker.source_nodes\n",
    "\n",
    "        for node in without_reranker_nodes:\n",
    "            if context in node.node.text or node.node.text in context:\n",
    "                without_reranker_hits += 1\n",
    "\n",
    "        response_with_base_reranker = retriever_with_base_reranker.query(query)\n",
    "        with_base_reranker_nodes = response_with_base_reranker.source_nodes\n",
    "\n",
    "        for node in with_base_reranker_nodes:\n",
    "            if context in node.node.text or node.node.text in context:\n",
    "                base_reranker_hits += 1\n",
    "\n",
    "        response_with_finetuned_reranker = (\n",
    "            retriever_with_finetuned_reranker.query(query)\n",
    "        )\n",
    "        with_finetuned_reranker_nodes = (\n",
    "            response_with_finetuned_reranker.source_nodes\n",
    "        )\n",
    "\n",
    "        for node in with_finetuned_reranker_nodes:\n",
    "            if context in node.node.text or node.node.text in context:\n",
    "                finetuned_reranker_hits += 1\n",
    "\n",
    "        assert (\n",
    "            len(with_finetuned_reranker_nodes)\n",
    "            == len(with_base_reranker_nodes)\n",
    "            == len(without_reranker_nodes)\n",
    "            == 3\n",
    "        )"
   ]
  },
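  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above increments a counter for every retrieved node that matches, so one question can contribute more than one hit when several retrieved chunks overlap its context. As a minimal sketch, assuming node texts are plain strings (`is_hit` and `hit_rate` are hypothetical helpers, not LlamaIndex APIs), a hit rate that counts each question at most once could look like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_hit(retrieved_texts, context):\n",
    "    # Same substring test as the loop above: the context lies inside a\n",
    "    # retrieved chunk, or a retrieved chunk lies inside the context.\n",
    "    return any(context in text or text in context for text in retrieved_texts)\n",
    "\n",
    "\n",
    "def hit_rate(results):\n",
    "    # results: one (retrieved_texts, context) pair per question.\n",
    "    # Each question contributes at most one hit.\n",
    "    hits = sum(is_hit(texts, ctx) for texts, ctx in results)\n",
    "    return hits / len(results) if results else 0.0\n",
    "\n",
    "\n",
    "# Hypothetical toy data, not taken from the dataset.\n",
    "toy_results = [\n",
    "    (['chunk A about models', 'chunk B'], 'chunk A about models'),\n",
    "    (['chunk C', 'chunk D'], 'unrelated evidence'),\n",
    "]\n",
    "print(hit_rate(toy_results))  # 0.5"
   ]
  },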
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Results\n",
    "\n",
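    "As the table below shows, the finetuned cross-encoder gets the most hits (37 of the 85 relevant contexts), compared with 34 for the base cross-encoder and 30 for OpenAI embeddings without a reranker."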
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Metric</th>\n",
       "      <th>OpenAI_Embeddings</th>\n",
       "      <th>Base_cross_encoder</th>\n",
       "      <th>Finetuned_cross_encoder</th>\n",
       "      <th>Total Relevant Context</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Hits</td>\n",
       "      <td>30</td>\n",
       "      <td>34</td>\n",
       "      <td>37</td>\n",
       "      <td>85</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  Metric  OpenAI_Embeddings  Base_cross_encoder  Finetuned_cross_encoder  \\\n",
       "0   Hits                 30                  34                       37   \n",
       "\n",
       "   Total Relevant Context  \n",
       "0                      85  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "without_reranker_scores = [without_reranker_hits]\n",
    "base_reranker_scores = [base_reranker_hits]\n",
    "finetuned_reranker_scores = [finetuned_reranker_hits]\n",
    "reranker_eval_dict = {\n",
    "    \"Metric\": \"Hits\",\n",
    "    \"OpenAI_Embeddings\": without_reranker_scores,\n",
    "    \"Base_cross_encoder\": base_reranker_scores,\n",
    "    \"Finetuned_cross_encoder\": finetuned_reranker_scores,\n",
    "    \"Total Relevant Context\": total_number_of_context,\n",
    "}\n",
    "df_reranker_eval_results = pd.DataFrame(reranker_eval_dict)\n",
    "display(df_reranker_eval_results)"
   ]
  },
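  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dividing each hit count by the 85 relevant contexts gives a rough hit rate per setup. It is only approximate, since the evaluation loop can count more than one matching node for the same question. A minimal sketch using the numbers from the results table:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hit counts and total copied from the results table above.\n",
    "hits = pd.Series(\n",
    "    {\n",
    "        'OpenAI_Embeddings': 30,\n",
    "        'Base_cross_encoder': 34,\n",
    "        'Finetuned_cross_encoder': 37,\n",
    "    }\n",
    ")\n",
    "total_relevant_context = 85\n",
    "\n",
    "# Approximate hit rate: hits per relevant context.\n",
    "hit_rates = hits / total_relevant_context\n",
    "print(hit_rates.round(3))"
   ]
  },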
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RAG Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2023-10-12 04:47:36--  https://www.dropbox.com/scl/fi/3lmzn6714oy358mq0vawm/test.csv?rlkey=yz16080te4van7fvnksi9kaed\n",
      "Resolving www.dropbox.com (www.dropbox.com)... 162.125.85.18, 2620:100:6035:18::a27d:5512\n",
      "Connecting to www.dropbox.com (www.dropbox.com)|162.125.85.18|:443... connected.\n",
      "HTTP request sent, awaiting response... 302 Found\n",
      "Location: https://ucb6087b1b853dad24e8201987fc.dl.dropboxusercontent.com/cd/0/inline/CFfI9UezsVwFpN4CHgYrSFveuNE01DfczDaeFGZO-Ud5VdDRff1LNG7hEhkBZwVljuRde-EZU336ASpnZs32qVePvpQEFnKB2SeplFpMt50G0m5IZepyV6pYPbNAhm0muYE_rjhlolHxRUQP_iaJBX9z/file# [following]\n",
      "--2023-10-12 04:47:38--  https://ucb6087b1b853dad24e8201987fc.dl.dropboxusercontent.com/cd/0/inline/CFfI9UezsVwFpN4CHgYrSFveuNE01DfczDaeFGZO-Ud5VdDRff1LNG7hEhkBZwVljuRde-EZU336ASpnZs32qVePvpQEFnKB2SeplFpMt50G0m5IZepyV6pYPbNAhm0muYE_rjhlolHxRUQP_iaJBX9z/file\n",
      "Resolving ucb6087b1b853dad24e8201987fc.dl.dropboxusercontent.com (ucb6087b1b853dad24e8201987fc.dl.dropboxusercontent.com)... 162.125.80.15, 2620:100:6035:15::a27d:550f\n",
      "Connecting to ucb6087b1b853dad24e8201987fc.dl.dropboxusercontent.com (ucb6087b1b853dad24e8201987fc.dl.dropboxusercontent.com)|162.125.80.15|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 1821706 (1.7M) [text/plain]\n",
      "Saving to: ‘test.csv’\n",
      "\n",
      "test.csv            100%[===================>]   1.74M  6.37MB/s    in 0.3s    \n",
      "\n",
      "2023-10-12 04:47:38 (6.37 MB/s) - ‘test.csv’ saved [1821706/1821706]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Download RAG Eval test data\n",
    "!wget -O test.csv https://www.dropbox.com/scl/fi/3lmzn6714oy358mq0vawm/test.csv?rlkey=yz16080te4van7fvnksi9kaed&dl=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of papers in the test sample:- 80\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import ast  # Used to safely evaluate the string as a list\n",
    "\n",
    "# Load Eval Data\n",
    "df_test = pd.read_csv(\"test.csv\", index_col=0)\n",
    "\n",
    "df_test[\"questions\"] = df_test[\"questions\"].apply(ast.literal_eval)\n",
    "df_test[\"answers\"] = df_test[\"answers\"].apply(ast.literal_eval)\n",
    "print(f\"Number of papers in the test sample:- {len(df_test)}\")"
   ]
  },
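  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `ast.literal_eval` calls are needed because, assuming the CSV was written from a DataFrame with list-valued columns, each list was stored as its string representation. A small illustration with made-up questions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "\n",
    "# Writing list-valued cells to CSV stores each list as its string repr,\n",
    "# which is what str() produces here.\n",
    "stored = str(['What models do they use?', 'How large is the dataset?'])\n",
    "\n",
    "# literal_eval only parses Python literals, so unlike eval() it will not\n",
    "# execute arbitrary code found in the file.\n",
    "questions = ast.literal_eval(stored)\n",
    "print(type(questions).__name__, len(questions))  # list 2"
   ]
  },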
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-8dd2786b-981d-4642-b009-9531bd14adde\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>paper</th>\n",
       "      <th>questions</th>\n",
       "      <th>answers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Identifying Condition-Action Statements in Med...</td>\n",
       "      <td>[What supervised machine learning models do th...</td>\n",
       "      <td>[Unacceptable, Unacceptable, 1470 sentences, U...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                                               paper  \\\n",
       "0  Identifying Condition-Action Statements in Med...   \n",
       "\n",
       "                                           questions  \\\n",
       "0  [What supervised machine learning models do th...   \n",
       "\n",
       "                                             answers  \n",
       "0  [Unacceptable, Unacceptable, 1470 sentences, U...  "
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect one sample of the eval data: a research paper, the questions about it, and their reference answers\n",
    "df_test.head(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Baseline Evaluation\n",
    "\n",
    "Retrieval uses OpenAI embeddings only, without any reranker."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Eval Method\n",
    "1. Iterate over each row of the test dataset:\n",
    "    1. Create a vector index from the document in the `paper` column of the current row.\n",
    "    2. Query the vector index with `similarity_top_k=3`, without any reranker.\n",
    "    3. Compare each generated answer with its reference answer using `PairwiseComparisonEvaluator`, skipping questions whose reference answer is `Unacceptable`, and average the scores for the row.\n",
    "2. Repeat step 1 until all rows have been processed.\n",
    "3. Average the per-row scores over all rows.\n"
   ]
  },
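  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scoring in this method is a two-level average: each paper's score is the mean over its accepted questions, and the final score is the mean over papers, so every paper counts equally regardless of how many questions it has. A minimal sketch of that aggregation, using a hypothetical helper and made-up scores in place of real `PairwiseComparisonEvaluator` results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def average_pairwise_score(papers):\n",
    "    # papers: one list per paper of (reference_answer, score) pairs.\n",
    "    per_paper = []\n",
    "    for rows in papers:\n",
    "        # Skip questions whose reference answer is 'Unacceptable'.\n",
    "        scores = [score for ref, score in rows if ref != 'Unacceptable']\n",
    "        if scores:  # papers with no accepted questions are left out\n",
    "            per_paper.append(sum(scores) / len(scores))\n",
    "    # Macro-average over papers (NaN if no paper had accepted questions).\n",
    "    return sum(per_paper) / len(per_paper) if per_paper else float('nan')\n",
    "\n",
    "\n",
    "# Hypothetical toy scores, not taken from the evaluation.\n",
    "toy_papers = [\n",
    "    [('ref a', 1.0), ('Unacceptable', None), ('ref b', 0.5)],\n",
    "    [('ref c', 0.0)],\n",
    "    [('Unacceptable', None)],\n",
    "]\n",
    "print(average_pairwise_score(toy_papers))  # 0.375"
   ]
  },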
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Response\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.core import Document\n",
    "from llama_index.core.evaluation import PairwiseComparisonEvaluator\n",
    "from llama_index.core.evaluation.eval_utils import (\n",
    "    get_responses,\n",
    "    get_results_df,\n",
    ")\n",
    "\n",
    "import os\n",
    "import openai\n",
    "import pandas as pd\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-\"\n",
    "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
    "\n",
    "gpt4 = OpenAI(temperature=0, model=\"gpt-4\")\n",
    "\n",
    "evaluator_gpt4_pairwise = PairwiseComparisonEvaluator(llm=gpt4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pairwise_scores_list = []\n",
    "\n",
    "no_reranker_dict_list = []\n",
    "\n",
    "\n",
    "# Iterate over the rows of the dataset\n",
    "for index, row in df_test.iterrows():\n",
    "    documents = [Document(text=row[\"paper\"])]\n",
    "    query_list = row[\"questions\"]\n",
    "    reference_answers_list = row[\"answers\"]\n",
    "    number_of_accepted_queries = 0\n",
    "    # Create vector index for the current row being iterated\n",
    "    vector_index = VectorStoreIndex.from_documents(documents)\n",
    "\n",
    "    # Retrieve the top 3 nodes by embedding similarity, without any reranker\n",
    "    query_engine = vector_index.as_query_engine(similarity_top_k=3)\n",
    "\n",
    "    assert len(query_list) == len(reference_answers_list)\n",
    "    pairwise_local_score = 0\n",
    "\n",
    "    for index in range(0, len(query_list)):\n",
    "        query = query_list[index]\n",
    "        reference = reference_answers_list[index]\n",
    "\n",
    "        if reference != \"Unacceptable\":\n",
    "            number_of_accepted_queries += 1\n",
    "\n",
    "            response = str(query_engine.query(query))\n",
    "\n",
    "            no_reranker_dict = {\n",
    "                \"query\": query,\n",
    "                \"response\": response,\n",
    "                \"reference\": reference,\n",
    "            }\n",
    "            no_reranker_dict_list.append(no_reranker_dict)\n",
    "\n",
    "            # Compare the generated answers with the reference answers of the respective sample using\n",
    "            # Pairwise Comparison Evaluator and add the scores to a list\n",
    "\n",
    "            pairwise_eval_result = await evaluator_gpt4_pairwise.aevaluate(\n",
    "                query=query, response=response, reference=reference\n",
    "            )\n",
    "\n",
    "            pairwise_score = pairwise_eval_result.score\n",
    "\n",
    "            pairwise_local_score += pairwise_score\n",
    "\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "    if number_of_accepted_queries > 0:\n",
    "        avg_pairwise_local_score = (\n",
    "            pairwise_local_score / number_of_accepted_queries\n",
    "        )\n",
    "        pairwise_scores_list.append(avg_pairwise_local_score)\n",
    "\n",
    "\n",
    "overal_pairwise_average_score = sum(pairwise_scores_list) / len(\n",
    "    pairwise_scores_list\n",
    ")\n",
    "\n",
    "df_responses = pd.DataFrame(no_reranker_dict_list)\n",
    "df_responses.to_csv(\"No_Reranker_Responses.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>pairwise score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Without Reranker</td>\n",
       "      <td>0.553788</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               name  pairwise score\n",
       "0  Without Reranker        0.553788"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results_dict = {\n",
    "    \"name\": [\"Without Reranker\"],\n",
    "    \"pairwise score\": [overal_pairwise_average_score],\n",
    "}\n",
    "results_df = pd.DataFrame(results_dict)\n",
    "display(results_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate with base reranker\n",
    "\n",
    "Retrieval uses OpenAI embeddings, with `cross-encoder/ms-marco-MiniLM-L-12-v2` as the reranker."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Eval Method\n",
    "1. Iterate over each row of the test dataset:\n",
    "    1. Create a vector index from the document in the `paper` column of the current row.\n",
    "    2. Query the vector index with `similarity_top_k=8`.\n",
    "    3. Use `cross-encoder/ms-marco-MiniLM-L-12-v2` as a reranker (a `NodePostprocessor`) to keep the top 3 of those 8 nodes.\n",
    "    4. Compare each generated answer with its reference answer using `PairwiseComparisonEvaluator`, skipping questions whose reference answer is `Unacceptable`, and average the scores for the row.\n",
    "2. Repeat step 1 until all rows have been processed.\n",
    "3. Average the per-row scores over all rows.\n"
   ]
  },
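  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `SentenceTransformerRerank` as a node postprocessor, retrieval happens in two stages: the vector index returns `similarity_top_k=8` candidates by embedding similarity, then the cross-encoder rescores each (query, node) pair and keeps `top_n=3`. A library-free sketch of that pattern, where `word_overlap` is a hypothetical stand-in for the cross-encoder (it is not how `ms-marco-MiniLM-L-12-v2` scores pairs) and the toy run uses smaller k values. In the toy run, the relevant chunk ranks fourth by embedding similarity, so plain top-2 retrieval would miss it, but reranking moves it to first place."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieve_then_rerank(nodes, query, rerank_score, similarity_top_k=8, top_n=3):\n",
    "    # nodes: (text, embedding_similarity) pairs for one document.\n",
    "    # Stage 1: keep the similarity_top_k most similar nodes.\n",
    "    candidates = sorted(nodes, key=lambda n: n[1], reverse=True)[:similarity_top_k]\n",
    "    # Stage 2: rescore each (query, text) pair and keep the best top_n.\n",
    "    reranked = sorted(candidates, key=lambda n: rerank_score(query, n[0]), reverse=True)\n",
    "    return [text for text, _ in reranked[:top_n]]\n",
    "\n",
    "\n",
    "def word_overlap(query, text):\n",
    "    # Hypothetical scorer: number of shared lowercase words.\n",
    "    return len(set(query.lower().split()) & set(text.lower().split()))\n",
    "\n",
    "\n",
    "# Hypothetical toy nodes with made-up similarity scores.\n",
    "toy_nodes = [\n",
    "    ('intro text', 0.9),\n",
    "    ('related work', 0.85),\n",
    "    ('dataset size', 0.8),\n",
    "    ('we train supervised models', 0.75),\n",
    "    ('conclusion', 0.5),\n",
    "]\n",
    "top = retrieve_then_rerank(\n",
    "    toy_nodes, 'which supervised models', word_overlap, similarity_top_k=4, top_n=2\n",
    ")\n",
    "print(top)  # ['we train supervised models', 'intro text']"
   ]
  },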
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "from llama_index.core.postprocessor import SentenceTransformerRerank\n",
    "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Response\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.core import Document\n",
    "from llama_index.core.evaluation import PairwiseComparisonEvaluator\n",
    "import os\n",
    "import openai\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-\"\n",
    "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
    "\n",
    "rerank = SentenceTransformerRerank(\n",
    "    model=\"cross-encoder/ms-marco-MiniLM-L-12-v2\", top_n=3\n",
    ")\n",
    "\n",
    "gpt4 = OpenAI(temperature=0, model=\"gpt-4\")\n",
    "\n",
    "evaluator_gpt4_pairwise = PairwiseComparisonEvaluator(llm=gpt4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pairwise_scores_list = []\n",
    "\n",
    "base_reranker_dict_list = []\n",
    "\n",
    "\n",
    "# Iterate over the rows of the dataset\n",
    "for index, row in df_test.iterrows():\n",
    "    documents = [Document(text=row[\"paper\"])]\n",
    "    query_list = row[\"questions\"]\n",
    "    reference_answers_list = row[\"answers\"]\n",
    "\n",
    "    number_of_accepted_queries = 0\n",
    "    # Create vector index for the current row being iterated\n",
    "    vector_index = VectorStoreIndex.from_documents(documents)\n",
    "\n",
    "    # Retrieve the top 8 nodes, then rerank them with\n",
    "    # cross-encoder/ms-marco-MiniLM-L-12-v2 and keep the top 3\n",
    "    query_engine = vector_index.as_query_engine(\n",
    "        similarity_top_k=8, node_postprocessors=[rerank]\n",
    "    )\n",
    "\n",
    "    assert len(query_list) == len(reference_answers_list)\n",
    "    pairwise_local_score = 0\n",
    "\n",
    "    for index in range(0, len(query_list)):\n",
    "        query = query_list[index]\n",
    "        reference = reference_answers_list[index]\n",
    "\n",
    "        if reference != \"Unacceptable\":\n",
    "            number_of_accepted_queries += 1\n",
    "\n",
    "            response = str(query_engine.query(query))\n",
    "\n",
    "            base_reranker_dict = {\n",
    "                \"query\": query,\n",
    "                \"response\": response,\n",
    "                \"reference\": reference,\n",
    "            }\n",
    "            base_reranker_dict_list.append(base_reranker_dict)\n",
    "\n",
    "            # Compare the generated answers with the reference answers of the respective sample using\n",
    "            # Pairwise Comparison Evaluator and add the scores to a list\n",
    "\n",
    "            pairwise_eval_result = await evaluator_gpt4_pairwise.aevaluate(\n",
    "                query=query, response=response, reference=reference\n",
    "            )\n",
    "\n",
    "            pairwise_score = pairwise_eval_result.score\n",
    "\n",
    "            pairwise_local_score += pairwise_score\n",
    "\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "    if number_of_accepted_queries > 0:\n",
    "        avg_pairwise_local_score = (\n",
    "            pairwise_local_score / number_of_accepted_queries\n",
    "        )\n",
    "        pairwise_scores_list.append(avg_pairwise_local_score)\n",
    "\n",
    "overal_pairwise_average_score = sum(pairwise_scores_list) / len(\n",
    "    pairwise_scores_list\n",
    ")\n",
    "\n",
    "df_responses = pd.DataFrame(base_reranker_dict_list)\n",
    "df_responses.to_csv(\"Base_Reranker_Responses.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>pairwise score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>With base cross-encoder/ms-marco-MiniLM-L-12-v...</td>\n",
       "      <td>0.556818</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                name  pairwise score\n",
       "0  With base cross-encoder/ms-marco-MiniLM-L-12-v...        0.556818"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results_dict = {\n",
    "    \"name\": [\"With base cross-encoder/ms-marco-MiniLM-L-12-v2 as Reranker\"],\n",
    "    \"pairwise score\": [overal_pairwise_average_score],\n",
    "}\n",
    "results_df = pd.DataFrame(results_dict)\n",
    "display(results_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate with Fine-Tuned re-ranker\n",
    "\n",
    "OpenAI Embeddings + `bpHigh/Cross-Encoder-LLamaIndex-Demo-v2` as reranker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Eval Method:-\n",
    "1. Iterate over each row of the test dataset:-\n",
    "    1. For the current row being iterated, create a vector index using the paper document provided in the paper column of the dataset\n",
    "    2. Query the vector index with a top_k value of top 5 nodes.\n",
    "    3. Use finetuned version of cross-encoder/ms-marco-MiniLM-L-12-v2 saved as bpHigh/Cross-Encoder-LLamaIndex-Demo as a reranker as a NodePostprocessor to get top_k value of top 3 nodes out of the 8 nodes\n",
    "    3. Compare the generated answers with the reference answers of the respective sample using Pairwise Comparison Evaluator and add the scores to a list\n",
    "5. Repeat 1 until all the rows have been iterated\n",
    "6. Calculate avg scores over all samples/ rows"
   ]
  },
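  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The final score is a two-level average: each row's score is the mean over its accepted questions, and the reported number is the mean of those row scores, so every paper counts equally regardless of how many questions it has. The following is a minimal plain-Python sketch of this aggregation. It assumes the evaluator scores each comparison as 1.0, 0.5, or 0.0; the function name `average_pairwise_score` and the scores are hypothetical, with `None` standing in for a skipped \"Unacceptable\" question.\n",
    "\n",
    "```python\n",
    "from statistics import mean\n",
    "\n",
    "\n",
    "def average_pairwise_score(rows):\n",
    "    \"\"\"Mean of per-row mean scores; rows with no accepted questions are skipped.\"\"\"\n",
    "    row_averages = []\n",
    "    for scores in rows:\n",
    "        # None marks a question whose reference answer was \"Unacceptable\"\n",
    "        accepted = [s for s in scores if s is not None]\n",
    "        if accepted:  # mirrors the `number_of_accepted_queries > 0` check\n",
    "            row_averages.append(sum(accepted) / len(accepted))\n",
    "    return mean(row_averages)\n",
    "\n",
    "\n",
    "# Hypothetical scores for three rows (papers)\n",
    "rows = [[1.0, 0.5, None], [0.0, 1.0], [None]]\n",
    "print(average_pairwise_score(rows))  # 0.625\n",
    "```"
   ]
  },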
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.postprocessor import SentenceTransformerRerank\n",
    "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Response\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.core import Document\n",
    "from llama_index.core.evaluation import PairwiseComparisonEvaluator\n",
    "import os\n",
    "import openai\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-\"\n",
    "openai.api_key = os.environ[\"OPENAI_API_KEY\"]\n",
    "\n",
    "rerank = SentenceTransformerRerank(\n",
    "    model=\"bpHigh/Cross-Encoder-LLamaIndex-Demo-v2\", top_n=3\n",
    ")\n",
    "\n",
    "\n",
    "gpt4 = OpenAI(temperature=0, model=\"gpt-4\")\n",
    "\n",
    "evaluator_gpt4_pairwise = PairwiseComparisonEvaluator(llm=gpt4)"
   ]
  },
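  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `similarity_top_k=8` and `top_n=3`, the query engine first retrieves 8 nodes by embedding similarity, and the reranker then scores each (query, node text) pair and keeps the 3 highest-scoring nodes. The following plain-Python sketch illustrates that retrieve-then-rerank step; it is not LlamaIndex's implementation. `rerank_top_n` and the word-overlap `overlap_scorer` are hypothetical stand-ins for the fine-tuned cross-encoder. With the sentence-transformers library installed, `CrossEncoder(model_name).predict(pairs)` returns one relevance score per pair and could be passed as the scorer instead.\n",
    "\n",
    "```python\n",
    "from typing import Callable, List, Sequence, Tuple\n",
    "\n",
    "\n",
    "def rerank_top_n(\n",
    "    query: str,\n",
    "    passages: Sequence[str],\n",
    "    score_pairs: Callable[[List[Tuple[str, str]]], Sequence[float]],\n",
    "    top_n: int = 3,\n",
    ") -> List[str]:\n",
    "    \"\"\"Score (query, passage) pairs and return the top_n passages by score.\"\"\"\n",
    "    pairs = [(query, p) for p in passages]\n",
    "    scores = score_pairs(pairs)\n",
    "    ranked = sorted(zip(scores, passages), key=lambda sp: sp[0], reverse=True)\n",
    "    return [p for _, p in ranked[:top_n]]\n",
    "\n",
    "\n",
    "def overlap_scorer(pairs):\n",
    "    \"\"\"Hypothetical toy scorer: number of shared lowercase words.\"\"\"\n",
    "    return [len(set(q.lower().split()) & set(p.lower().split())) for q, p in pairs]\n",
    "\n",
    "\n",
    "# Stand-ins for the 8 nodes returned by similarity_top_k=8\n",
    "retrieved = [\n",
    "    \"cross encoders score query passage pairs\",\n",
    "    \"bananas are yellow\",\n",
    "    \"rerank passages with a cross encoder\",\n",
    "    \"the weather is nice\",\n",
    "    \"fine tuned cross encoder for reranking passages\",\n",
    "    \"unrelated text\",\n",
    "    \"another unrelated line\",\n",
    "    \"query passage relevance\",\n",
    "]\n",
    "top3 = rerank_top_n(\"rerank passages with a cross encoder\", retrieved, overlap_scorer)\n",
    "print(top3)\n",
    "```"
   ]
  },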
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pairwise_scores_list = []\n",
    "\n",
    "\n",
    "finetuned_reranker_dict_list = []\n",
    "\n",
    "# Iterate over the rows of the dataset\n",
    "for index, row in df_test.iterrows():\n",
    "    documents = [Document(text=row[\"paper\"])]\n",
    "    query_list = row[\"questions\"]\n",
    "    reference_answers_list = row[\"answers\"]\n",
    "\n",
    "    number_of_accepted_queries = 0\n",
    "    # Create vector index for the current row being iterated\n",
    "    vector_index = VectorStoreIndex.from_documents(documents)\n",
    "\n",
    "    # Query the vector index with a top_k value of top 8 nodes with reranker\n",
    "    # as cross-encoder/ms-marco-MiniLM-L-12-v2\n",
    "    query_engine = vector_index.as_query_engine(\n",
    "        similarity_top_k=8, node_postprocessors=[rerank]\n",
    "    )\n",
    "\n",
    "    assert len(query_list) == len(reference_answers_list)\n",
    "    pairwise_local_score = 0\n",
    "\n",
    "    for index in range(0, len(query_list)):\n",
    "        query = query_list[index]\n",
    "        reference = reference_answers_list[index]\n",
    "\n",
    "        if reference != \"Unacceptable\":\n",
    "            number_of_accepted_queries += 1\n",
    "\n",
    "            response = str(query_engine.query(query))\n",
    "\n",
    "            finetuned_reranker_dict = {\n",
    "                \"query\": query,\n",
    "                \"response\": response,\n",
    "                \"reference\": reference,\n",
    "            }\n",
    "            finetuned_reranker_dict_list.append(finetuned_reranker_dict)\n",
    "\n",
    "            # Compare the generated answers with the reference answers of the respective sample using\n",
    "            # Pairwise Comparison Evaluator and add the scores to a list\n",
    "\n",
    "            pairwise_eval_result = await evaluator_gpt4_pairwise.aevaluate(\n",
    "                query, response=response, reference=reference\n",
    "            )\n",
    "\n",
    "            pairwise_score = pairwise_eval_result.score\n",
    "\n",
    "            pairwise_local_score += pairwise_score\n",
    "\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "    if number_of_accepted_queries > 0:\n",
    "        avg_pairwise_local_score = (\n",
    "            pairwise_local_score / number_of_accepted_queries\n",
    "        )\n",
    "        pairwise_scores_list.append(avg_pairwise_local_score)\n",
    "\n",
    "overal_pairwise_average_score = sum(pairwise_scores_list) / len(\n",
    "    pairwise_scores_list\n",
    ")\n",
    "df_responses = pd.DataFrame(finetuned_reranker_dict_list)\n",
    "df_responses.to_csv(\"Finetuned_Reranker_Responses.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>pairwise score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>With fine-tuned cross-encoder/ms-marco-MiniLM-...</td>\n",
       "      <td>0.6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                name  pairwise score\n",
       "0  With fine-tuned cross-encoder/ms-marco-MiniLM-...             0.6"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results_dict = {\n",
    "    \"name\": [\"With fine-tuned cross-encoder/ms-marco-MiniLM-L-12-v2\"],\n",
    "    \"pairwise score\": [overal_pairwise_average_score],\n",
    "}\n",
    "results_df = pd.DataFrame(results_dict)\n",
    "display(results_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Results\n",
    "\n",
    "As we can see we get the highest pairwise score with finetuned cross-encoder.\n",
    "\n",
    "Although I would like to point that the reranking eval based on hits is a more robust metric compared to pairwise comparision evaluator as I have seen inconsistencies with the scores and there are also many inherent biases present when evaluating using GPT-4"
   ]
  }
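  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hits-based reranking evaluation checks whether the ground-truth context for each question survives reranking. Below is a minimal plain-Python sketch, assuming each question has exactly one relevant node id and the reranker's output is available as an ordered list of node ids. `hit_rate_and_mrr` and the ids are hypothetical. Mean reciprocal rank (MRR) is included as a common companion metric; the notebook's own hits evaluation may be computed differently.\n",
    "\n",
    "```python\n",
    "def hit_rate_and_mrr(ranked_ids_per_query, expected_ids):\n",
    "    \"\"\"Hit rate and MRR over queries that each have one relevant node id.\"\"\"\n",
    "    hits, reciprocal_rank_sum = 0, 0.0\n",
    "    for ranked_ids, expected in zip(ranked_ids_per_query, expected_ids):\n",
    "        if expected in ranked_ids:\n",
    "            hits += 1\n",
    "            # 1-based rank of the relevant node in the reranked list\n",
    "            reciprocal_rank_sum += 1.0 / (ranked_ids.index(expected) + 1)\n",
    "    n = len(expected_ids)\n",
    "    return hits / n, reciprocal_rank_sum / n\n",
    "\n",
    "\n",
    "# Hypothetical reranked top-3 node ids and ground-truth ids for three questions\n",
    "ranked = [[\"a\", \"b\", \"c\"], [\"d\", \"e\", \"f\"], [\"g\", \"h\", \"i\"]]\n",
    "expected = [\"b\", \"x\", \"g\"]\n",
    "print(hit_rate_and_mrr(ranked, expected))  # (0.6666666666666666, 0.5)\n",
    "```"
   ]
  }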
 ],
 "metadata": {
  "colab": {
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
